{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47174023-c012-46fb-8687-893e880d7bf3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.7/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.10) or chardet (5.0.0)/charset_normalizer (2.0.12) doesn't match a supported version!\n",
      "  RequestsDependencyWarning)\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import tqdm\n",
    "import random\n",
    "from typing import List\n",
    "import itertools as it\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import fmin_cobyla\n",
    "\n",
    "from mindquantum.core.operators import QubitOperator\n",
    "from mindquantum.core.operators import Hamiltonian\n",
    "from mindquantum.simulator import Simulator\n",
    "from mindquantum import H, X, Z, RY, UN, CNOT, Circuit\n",
    "import mindspore.context as context\n",
    "\n",
    "context.set_context(mode=context.PYNATIVE_MODE, device_target=\"CPU\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c5bf5b6-80f5-425d-aa25-3bb2aca293e2",
   "metadata": {},
   "source": [
    "## 1. 数据集\n",
    "\n",
    "整个训练过程和机器学习一致，首先应准备数据集，然后使用 VQA 训练网络参数。\n",
    "\n",
    "将论文 [1][Francesco Scala, et al. Quantum variational learning for entanglement witnessing] 中 `A. Exact witness computation` 识别出的纠缠态（共 18 个，对应图 5 中红线以上的态）的标签设为 1；将可分离（bi-separable）态（对应图 5 中的橙色柱）的前一半（编号 ≤ 127）的标签设为 0；另外从未识别出的纠缠态的前一半中随机抽取 76 个（代码中为 `int(256 / 2 * 0.6)`，即 128 的 60%），标签也设为 0。只取前一半是因为数据具有对称性：把 REW 态所有分量的符号同时取反只改变全局相位，因此编号 $i$ 与 $255-i$ 表示同一个态。\n",
    "\n",
    "此处将获取可分离态、获取识别纠缠态等用代码复现。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d153babc-8468-4a55-a647-441c626bd152",
   "metadata": {},
   "source": [
    "### 1.1 Bi-separable states\n",
    "\n",
    "通过逐一枚举，获取 3 比特系统中所有可以写成两个子系统直积的态（bi-separable states）。REW 态是所有计算基矢的等权重叠加，各个 REW 态之间只有每个分量前的正负号不同。3 比特 REW 态共有 8 个分量，因此符号共有 $2^8 = 256$ 种组合，可以用 0–255 的整数表示：负号记为 1，正号记为 0，按 $|000\\rangle, |001\\rangle, \\dots, |111\\rangle$ 的顺序组成 8 位二进制串（第一个分量为最高位），再转换成十进制即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ae8a30a5-bd71-438e-84b7-1429f3742a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_separable_states() -> List:\n",
    "    \"\"\"Enumerate all 3-qubit bi-separable REW states as digits in [0, 255].\n",
    "\n",
    "    A REW state is encoded by the signs of its 8 amplitudes in the order\n",
    "    |000>, |001>, ..., |111>: sign -1 becomes bit 1, sign +1 becomes bit 0,\n",
    "    and the first amplitude is the most significant bit (the same convention\n",
    "    as `int2signs`). A state is bi-separable if its sign vector is the\n",
    "    Kronecker product of a 1-qubit and a 2-qubit sign vector for at least one\n",
    "    position of the single qubit in the bit string.\n",
    "\n",
    "    Flipping every sign only changes the global phase, so the returned set is\n",
    "    closed under i -> 255 - i.\n",
    "\n",
    "    Returns:\n",
    "        Sorted list of distinct digits (Python ints).\n",
    "    \"\"\"\n",
    "    # Sign patterns of a 1-qubit factor (2 amplitudes) and a 2-qubit factor\n",
    "    # (4 amplitudes); their Kronecker products are the product-state patterns.\n",
    "    pm1bit = list(it.product([-1, 1], repeat=2))\n",
    "    pm2bit = list(it.product([-1, 1], repeat=4))\n",
    "    pm3bit = [np.kron(a, b) for a in pm1bit for b in pm2bit]\n",
    "    str1bit = ['0', '1']\n",
    "    str2bit = ['00', '01', '10', '11']\n",
    "    # Basis label of each kron entry (a, b) for the single qubit placed first,\n",
    "    # in the middle, or last in the bit string. Swapping the two qubits of the\n",
    "    # 2-qubit factor is unnecessary: pm2bit already holds all its patterns.\n",
    "    layouts = [\n",
    "        [a + b for a in str1bit for b in str2bit],\n",
    "        [b[0] + a + b[1] for a in str1bit for b in str2bit],\n",
    "        [b + a for a in str1bit for b in str2bit],\n",
    "    ]\n",
    "    weights = np.array([128, 64, 32, 16, 8, 4, 2, 1])\n",
    "    sep_digits = set()\n",
    "    for labels in layouts:\n",
    "        # permutation that reorders kron entries into |000>, |001>, ..., |111>\n",
    "        order = np.argsort(labels)\n",
    "        for pm in pm3bit:\n",
    "            bits = (pm[order] == -1).astype(int)  # -1 -> 1, +1 -> 0\n",
    "            sep_digits.add(int((bits * weights).sum()))\n",
    "    return sorted(sep_digits)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6717982e-b9dd-4f12-a506-4d32d3eeead7",
   "metadata": {},
   "source": [
    "### 1.2 Recognized entangled states\n",
    "\n",
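    "取参考态 $|H\\rangle$，其符号编码为 $[0, 0, 0, 0, 0, 1, 1, 0]$（1 表示负号），计算各个态与参考态的激活值。激活值对应保真度 $\\mathrm{Tr}[\\rho |H\\rangle\\langle H|] = |\\langle H|\\psi\\rangle|^2$，而代码中 `get_overlap` 计算的是重叠的模 $|\\langle H|\\psi\\rangle|$。对于 REW 态，$|\\langle H|\\psi\\rangle| = |8 - 2d| / 8$（$d$ 为两个符号串的汉明距离），取阈值 0.5 时两种定义选出的态完全相同（$d \\in \\{0, 1, 7, 8\\}$，共 18 个）。如果激活值高于阈值（文中取 0.5），即为识别出来的纠缠态。由于两个 REW 态只在各分量前的符号上不同，因此可以用异或运算计算它们的重叠。"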
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ff6a035-40a7-4810-ab72-e4a7b5cb8488",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_overlap(fx: List, ref: List, n=3) -> float:\n",
    "    \"\"\"Return |<ref|fx>| for two REW states given as sign bit strings.\n",
    "\n",
    "    Both states are equal-weight superpositions of 2**n basis states that\n",
    "    differ only in signs, so every position contributes +1 when the bits\n",
    "    agree and -1 when they differ; the sum is normalized by 2**n.\n",
    "\n",
    "    Args:\n",
    "        fx: sign bits (0 for +, 1 for -) of the input state.\n",
    "        ref: sign bits of the reference state.\n",
    "        n: number of qubits.\n",
    "\n",
    "    Return:\n",
    "        act: magnitude of the overlap amplitude; square it to obtain the\n",
    "        fidelity |<ref|fx>|**2.\n",
    "    \"\"\"\n",
    "    parity = np.array([a ^ b for a, b in zip(fx, ref)])\n",
    "    signed_sum = ((-1) ** parity).sum()\n",
    "    return np.abs(signed_sum) / 2 ** n\n",
    "\n",
    "\n",
    "def get_recognized_entangled_states():\n",
    "    \"\"\"Return the digits of all REW states whose overlap with the reference\n",
    "    state is greater than 0.5 (the recognized entangled states).\n",
    "\n",
    "    Returns:\n",
    "        List of digits in increasing order (numpy integers); the first of\n",
    "        the 8 sign bits is the most significant.\n",
    "    \"\"\"\n",
    "    n = 3                                     # number of qubits\n",
    "    ref = [0, 0, 0, 0, 0, 1, 1, 0]            # sign bits of the reference state\n",
    "    weights = np.array([128, 64, 32, 16, 8, 4, 2, 1])\n",
    "    return [\n",
    "        (np.array(fx) * weights).sum()\n",
    "        for fx in it.product([0, 1], repeat=8)\n",
    "        if get_overlap(fx, ref, n) > 0.5\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "983d4152-418a-4784-b454-20e2946209a3",
   "metadata": {},
   "source": [
    "### 1.3 训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1cb8637f-7da3-4bcf-8193-80d25a8cebe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_training_data(seed=None):\n",
    "    \"\"\"Prepare the training data.\n",
    "\n",
    "    Label 0: bi-separable states with digit <= 127, plus 76 randomly chosen\n",
    "    unrecognized entangled states with digit <= 127. Label 1: all recognized\n",
    "    entangled states. Only digits <= 127 are used for label 0 because the\n",
    "    digits i and 255 - i describe the same state up to a global phase.\n",
    "\n",
    "    Args:\n",
    "        seed: optional seed for the random sample. None (default) draws from\n",
    "            the global `random` module; an int makes the sample reproducible\n",
    "            without touching the global random state.\n",
    "\n",
    "    Returns:\n",
    "        (x, y): list of state digits and list of 0/1 labels.\n",
    "    \"\"\"\n",
    "    # bi-separable states\n",
    "    sep_states = get_separable_states()\n",
    "    # recognized entangled states\n",
    "    rec_ent_states = get_recognized_entangled_states()\n",
    "    # unrecognized entangled states, sorted so the sampling population has an\n",
    "    # explicit, stable order\n",
    "    unrec_ent_states = sorted(\n",
    "        set(range(256)) - set(sep_states) - set(rec_ent_states)\n",
    "    )\n",
    "    # lower half of the digits only (symmetry under flipping all signs)\n",
    "    sep_half = [s for s in sep_states if s <= 127]\n",
    "    unrec_half = [s for s in unrec_ent_states if s <= 127]\n",
    "    # NOTE(review): int(256 / 2 * 0.6) = 76 is 60% of 128, not 60% of the\n",
    "    # unrecognized lower half; confirm against the paper's proportion.\n",
    "    n_unrec = int(256 / 2 * 0.6)\n",
    "    rng = random if seed is None else random.Random(seed)\n",
    "    x0 = sep_half + rng.sample(unrec_half, n_unrec)\n",
    "    y0 = [0] * len(x0)\n",
    "    x1 = rec_ent_states\n",
    "    y1 = [1] * len(x1)\n",
    "    return (x0 + x1, y0 + y1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96e7ad84-b24c-43d8-9721-0046faed9732",
   "metadata": {},
   "source": [
    "## 2. 量子线路\n",
    "\n",
    "量子线路主要包括编码器（Encoder）线路和 ansatz。编码器即用量子线路制备给定的 REW 态，采用论文 [2][M. Rossi, M. Huber, D. Bruß, et al. Quantum hypergraph states] 提出的方法实现：先对每个比特作用 H 门，再根据符号依次作用 Z、CZ、CCZ 门。ansatz 比较简单，直接按照论文 [1] Fig. 2 实现即可。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebcb728c-33f3-4de8-91cd-631a0d3afb0b",
   "metadata": {},
   "source": [
    "### 2.1 Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1d01a8e5-f11b-4e42-9dc5-5f873eece8f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def int2signs(i: int) -> List:\n",
    "    \"\"\"Convert an integer in [0, 255] to its 8-element sign array.\n",
    "\n",
    "    The 8-bit binary form of `i` is read most significant bit first; bit 1\n",
    "    maps to -1 and bit 0 maps to 1. For example 7 = |0000 0111> gives\n",
    "    [1, 1, 1, 1, 1, -1, -1, -1] and 11 = |0000 1011> gives\n",
    "    [1, 1, 1, 1, -1, 1, -1, -1].\n",
    "\n",
    "    Args:\n",
    "        i: the input integer.\n",
    "    Return:\n",
    "        signs: the sign array corresponding to `i`.\n",
    "    Raises:\n",
    "        ValueError: if `i` is outside [0, 255].\n",
    "    \"\"\"\n",
    "    if not 0 <= i <= 255:\n",
    "        raise ValueError(f\"i must be in [0, 255], got {i}\")\n",
    "    return [-1 if (i >> shift) & 1 else 1 for shift in range(7, -1, -1)]\n",
    "\n",
    "\n",
    "def get_operation(signs: List) -> List:\n",
    "    \"\"\"Find the Z / CZ / CCZ gates that produce the REW state with the given\n",
    "    signs, following M. Rossi, M. Huber, D. Bruß, et al., Quantum hypergraph\n",
    "    states.\n",
    "\n",
    "    Index k of `signs` is a basis state whose bit q belongs to qubit q. A Z\n",
    "    gate on the qubit set S flips the sign of every basis state k containing\n",
    "    all qubits of S. Sets are visited by increasing size (single qubits, then\n",
    "    pairs, then all three); whenever the sign at the index equal to the set\n",
    "    is still -1, that gate is recorded and its flip is applied.\n",
    "\n",
    "    Args:\n",
    "        signs: the 8 signs (-1 or 1) of the REW state. A leading -1 is a global\n",
    "            phase and is removed first.\n",
    "\n",
    "    Return:\n",
    "        ops: tuples of qubit indices; the first is the Z target and the rest\n",
    "            are controls.\n",
    "    \"\"\"\n",
    "    remaining = np.array(signs)\n",
    "    if remaining[0] == -1:\n",
    "        remaining = -remaining  # drop the global phase\n",
    "    ops = []\n",
    "    # qubit-set masks by size: {0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}\n",
    "    for mask in (1, 2, 4, 3, 5, 6, 7):\n",
    "        if remaining[mask] != -1:\n",
    "            continue\n",
    "        affected = [k for k in range(8) if (k & mask) == mask]\n",
    "        remaining[affected] *= -1\n",
    "        ops.append(tuple(q for q in range(3) if (mask >> q) & 1))\n",
    "    return ops\n",
    "\n",
    "\n",
    "def prepare_states_circuit(ops: List) -> Circuit:\n",
    "    \"\"\"Build the circuit that prepares a 3-qubit REW state.\n",
    "\n",
    "    Args:\n",
    "        ops: qubit tuples as returned by `get_operation`; the first entry of\n",
    "            each tuple is the Z target, the remaining entries are controls.\n",
    "\n",
    "    Return:\n",
    "        cir: H on qubits 0, 1, 2 followed by one (multi-)controlled Z per op.\n",
    "    \"\"\"\n",
    "    circuit = Circuit()\n",
    "    circuit += UN(H, [0, 1, 2])\n",
    "    for op in ops:\n",
    "        target, controls = op[0], op[1:]\n",
    "        circuit += Z.on(target, controls)\n",
    "    return circuit\n",
    "\n",
    "\n",
    "def get_encoder(i: int) -> Circuit:\n",
    "    \"\"\"Return the encoder circuit for the REW state with digit `i`.\n",
    "\n",
    "    Args:\n",
    "        i: digit in [0, 255] holding the sign bits of a REW state.\n",
    "\n",
    "    Return:\n",
    "        The circuit that prepares that state (up to a global phase).\n",
    "    \"\"\"\n",
    "    # digit -> sign array -> Z/CZ/CCZ list -> circuit\n",
    "    return prepare_states_circuit(get_operation(int2signs(i)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e8b8577-70dd-4cb5-b7e8-d8d1ad647df0",
   "metadata": {},
   "source": [
    "### 2.2 Ansatz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fbe38ce8-309c-4b85-b429-dadf92290452",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ansatz():\n",
    "    \"\"\"Build the ansatz of Francesco Scala et al., Quantum variational\n",
    "    learning for entanglement witnessing.\n",
    "\n",
    "    Three layers of RY rotations on qubits 0, 1, 2 with parameters\n",
    "    theta0 ... theta8 (three per layer, in qubit order), and the block\n",
    "    CNOT(1, 0), CNOT(2, 0), CNOT(2, 1) between consecutive layers.\n",
    "    \"\"\"\n",
    "    ansatz = Circuit()\n",
    "    for layer in range(3):\n",
    "        if layer:\n",
    "            # entangling block between consecutive rotation layers\n",
    "            ansatz += CNOT(1, 0)\n",
    "            ansatz += CNOT(2, 0)\n",
    "            ansatz += CNOT(2, 1)\n",
    "        for qubit in range(3):\n",
    "            ansatz += RY(f'theta{3 * layer + qubit}').on(qubit)\n",
    "    return ansatz"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f5fbfa7-e1d0-4651-88cd-f3075c6eb5eb",
   "metadata": {},
   "source": [
    "### 2.3 Full circuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "89bdde9e-3751-45dc-bb1e-d6c6b0e8f3bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_full_circuit(i: int) -> Circuit:\n",
    "    \"\"\"Return encoder(i) + ansatz on qubits 0-2, followed by an X gate on\n",
    "    qubit 3 controlled by qubits 0, 1 and 2.\n",
    "\n",
    "    Assuming qubit 3 starts in |0>, it is flipped only when qubits 0-2 are\n",
    "    all |1>, so its |1> probability equals that of the basis state |1111>.\n",
    "    \"\"\"\n",
    "    circuit = get_encoder(i)\n",
    "    circuit += get_ansatz()\n",
    "    circuit += X(3, [0, 1, 2])\n",
    "    return circuit"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fa375c8-54cd-4db2-966d-3fb6475ec2e3",
   "metadata": {},
   "source": [
    "## 3. 训练与验证\n",
    "\n",
    "方法一：使用梯度下降法（gradient descent）训练网络参数。单个样本的梯度容易受该样本本身的影响，因此每次迭代随机抽取多个样本（代码中为 16 个）累加梯度，以减少误差；同时让学习率按指数衰减（每次迭代乘以 0.99），避免震荡。标签取值为 $0$ 或 $1$，以 $Z_3$ 作为哈密顿量，测量第 4 个量子比特的期望值（取值范围为 $[-1, 1]$），并由它得到激活概率。损失采用交叉熵：在梯度计算中，由期望值对参数的梯度按链式法则得到交叉熵损失对参数的梯度。由于正负样本不平衡，对数量较少的正样本（标签 $1$）的梯度乘以权重（代码中取 $10$）；为了避免梯度过大，使用 `clip` 函数将梯度截断。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "37429018-e286-4d26-b12d-db61f2a99bbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_gd(n_iter=300, batch_size=16, lr=0.02, decay=0.99,\n",
    "             pos_weight=10.0, clip=20.0):\n",
    "    \"\"\"Train the 9 ansatz parameters with mini-batch gradient descent.\n",
    "\n",
    "    The activation of a state is p = P(qubit 3 = |1>) = (1 - <Z3>) / 2, which\n",
    "    equals the |1111> probability that `validate_and_visualize` plots and\n",
    "    thresholds at 0.5. The loss is the binary cross entropy\n",
    "    L = -[y log p + (1 - y) log(1 - p)]. With g = d<Z3>/dtheta and\n",
    "    dp/dtheta = -g / 2, its gradient is y / (2p) * g - (1 - y) / (2(1 - p)) * g.\n",
    "\n",
    "    Args:\n",
    "        n_iter: number of iterations.\n",
    "        batch_size: samples drawn (with replacement) per iteration.\n",
    "        lr: initial learning rate.\n",
    "        decay: factor applied to the learning rate after every iteration.\n",
    "        pos_weight: weight of the label-1 gradient (the smaller class).\n",
    "        clip: per-sample gradient terms are clipped to [-clip, clip].\n",
    "\n",
    "    Returns:\n",
    "        The trained parameters as a float32 array of length 9.\n",
    "    \"\"\"\n",
    "    n_qubits = 4     # 3 data qubits + 1 output qubit\n",
    "    eps = 1e-12      # keeps p away from 0 and 1 in the divisions below\n",
    "\n",
    "    sim = Simulator('projectq', n_qubits)\n",
    "    ham = Hamiltonian(QubitOperator('Z3'))\n",
    "    weight = np.zeros(9, dtype=np.float32)\n",
    "    x, y = prepare_training_data()\n",
    "    samples = list(zip(x, y))  # built once instead of on every draw\n",
    "\n",
    "    for _ in tqdm.tqdm(range(n_iter)):\n",
    "        grad = np.zeros_like(weight)\n",
    "        for _ in range(batch_size):\n",
    "            xi, yi = random.choice(samples)\n",
    "            cir = get_full_circuit(xi)\n",
    "            grad_ops = sim.get_expectation_with_grad(ham, cir)\n",
    "            f, g = grad_ops(weight)\n",
    "            # The activation is P(q3 = |1>) = (1 - <Z3>) / 2; using\n",
    "            # (1 + <Z3>) / 2 would be P(q3 = |0>) and optimize the wrong loss.\n",
    "            p = np.clip(np.real((1.0 - f[0, 0]) / 2.0), eps, 1.0 - eps)\n",
    "            g = np.real(g.ravel())  # d<Z3>/dtheta\n",
    "            # dL/dtheta split into the label-1 and label-0 terms\n",
    "            g_pos = yi / (2.0 * p) * g\n",
    "            g_neg = -(1 - yi) / (2.0 * (1.0 - p)) * g\n",
    "            g_pos = np.clip(pos_weight * g_pos, -clip, clip)\n",
    "            g_neg = np.clip(g_neg, -clip, clip)\n",
    "            grad += g_pos + g_neg\n",
    "        weight -= lr * grad\n",
    "        lr *= decay\n",
    "    return weight"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b543e62f-be92-4df5-a18c-7036c053f345",
   "metadata": {},
   "source": [
    "方法二：基于所给数据进行训练，按照论文中使用交叉熵作为损失函数，使用 COBYLA 优化器更新参数。调用 `scipy` 库的优化函数 `fmin_cobyla`。验证时将各个 REW 态输入，查看网络判别结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "74a94867-f0bc-4d42-8d74-4cb64d3ba8c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_cobyla():\n",
    "    \"\"\"Train the network with the COBYLA optimizer.\n",
    "\n",
    "    The objective is the binary cross entropy summed over the training set,\n",
    "    with activation p = probability of the basis state |1111>.\n",
    "\n",
    "    Returns:\n",
    "        The 9 parameters returned by `fmin_cobyla`.\n",
    "    \"\"\"\n",
    "    x, y = prepare_training_data()\n",
    "    print(f\"Finish preparing data, with {len(x)} items.\")\n",
    "    # The circuits do not depend on the parameters, so build them once rather\n",
    "    # than on every objective evaluation.\n",
    "    circuits = [get_full_circuit(xi) for xi in x]\n",
    "    eps = 1e-12  # keeps log() finite if p reaches exactly 0 or 1\n",
    "\n",
    "    def objective(pr):\n",
    "        \"\"\"Summed cross entropy for the parameters `pr`.\"\"\"\n",
    "        v = 0.0\n",
    "        for cir, yi in zip(circuits, y):\n",
    "            qs = cir.get_qs(pr=pr)\n",
    "            p1 = np.clip(np.abs(qs[-1]) ** 2, eps, 1.0 - eps)\n",
    "            v -= yi * np.log(p1) + (1 - yi) * np.log(1 - p1)\n",
    "        return v\n",
    "\n",
    "    params = np.random.uniform(size=9)\n",
    "    print(\"Begin training, it may take several minutes on cpu.\")\n",
    "    best_pr = fmin_cobyla(objective, params, cons=[])\n",
    "    print(f\"Finish training, the best parameters are:\\n{best_pr}.\\n\")\n",
    "    return best_pr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85831e55-f589-4fa0-938d-664a782f78d6",
   "metadata": {},
   "source": [
    "验证和可视化代码："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a680cce0-3e0e-4ec7-b95f-599636dcb866",
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate_and_visualize(best_pr, save_path=\"fig5.png\"):\n",
    "    \"\"\"Plot the activation of all 256 REW states (cf. Fig. 5 of the paper).\n",
    "\n",
    "    The activation of a state is the probability of the basis state |1111>\n",
    "    after the full circuit. Bi-separable states are drawn in orange, all\n",
    "    other states in cyan, and the 0.5 threshold as a red line.\n",
    "\n",
    "    Args:\n",
    "        best_pr: The trained parameters, passed to `get_qs` as `pr`.\n",
    "        save_path: File the figure is saved to.\n",
    "    \"\"\"\n",
    "    acts = np.array([\n",
    "        np.abs(get_full_circuit(i).get_qs(pr=best_pr)[-1]) ** 2\n",
    "        for i in range(256)\n",
    "    ])\n",
    "    # bi-separable states\n",
    "    sep_states = get_separable_states()\n",
    "    fig, ax = plt.subplots(figsize=(12, 5))\n",
    "    ax.bar(range(256), acts, color='cyan', width=1.0, label='entangled')\n",
    "    ax.bar(sep_states, acts[sep_states], color='orange', width=1.0,\n",
    "           label='bi-separable')\n",
    "    ax.plot([0, 255], [0.5, 0.5], color='red', label='threshold 0.5')\n",
    "    ax.set_xlabel(\"Hypergraph states\")\n",
    "    ax.set_ylabel(\"Activation\")\n",
    "    ax.set_title(\"Activation of the 256 REW states\")\n",
    "    ax.legend()\n",
    "    fig.savefig(save_path)\n",
    "    print(f\"The result image has been save at: {save_path}\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f5292e-1e1f-46cf-8501-bf8ee7dc01f2",
   "metadata": {},
   "source": [
    "## 4. 运行验证\n",
    "\n",
    "对使用梯度下降法和COBYLA方法均进行验证。可以看到其输出结果与论文[1] Fig.5 基本一致。取阈值为 0.5，在 256 个态中，所有可分离态（64个）被正确识别，192个纠缠态中有 18 个纠缠态被正确识别。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "32573542-e269-4a34-9137-d7c42272f3c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 300/300 [04:26<00:00,  1.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The result image has been save at: fig5.png\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAtAAAAE9CAYAAAAiZVVdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbIklEQVR4nO3dfZBsZ10n8O8vL4hAAN1cWM2LiW7EjcgGvBWgtNwgiAkCcVcWEtcXlCUoRMXFLeNqIaLWrmupCAIa3QgqIcTXvVtGomIQi1owN0sMJBi8BpRk0YS3gCKv/vaPOTd0JjN3+uTOmenu+XyqpqbP6WdO/2aeOd3ffvo551R3BwAAmM8xu10AAAAsEwEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARjhutwsY68QTT+zTTjttt8sAAGDFXXfdde/v7n3r1y9dgD7ttNNy8ODB3S4DAIAVV1V/s9F6UzgAAGAEARoAAEYQoAEAYAQBGgAARhCgAQBghMkCdFVdVlW3V9U7Nrm/quqlVXWoqm6oqkdNVQsAAGyXKUegX5Xk3CPcf16SM4avi5K8csJaAABgW0wWoLv7TUk+eIQm5yf5tV7zliQPrqovmKoeAADYDrs5B/qkJO+dWb51WAcAAAtrKQ4irKqLqupgVR284447drscAAD2sN0M0LclOWVm+eRh3T1096Xdvb+79+/bd4/LkQMAwI7ZzQB9IMm3DWfjeEySO7v7fbtYz2Rq+AIA2GtWMQcdN9WGq+q1Sc5JcmJV3ZrkR5McnyTd/YtJrkrypCSHknwsyXdMVQsAAGyXyQJ0d1+4xf2d5HlTPT4AAExhKQ4iBACARSFAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjDBpgK6qc6vq5qo6VFWXbHD/qVV1TVW9rapuqKonTVkPAAAcrckCdFUdm+TlSc5LcmaSC6vqzHXNfiTJld39yCQXJHnFVPUAAMB2mHIE+uwkh7r7lu7+ZJIrkpy/rk0neeBw+0FJ/t+E9QAAwFE7bsJtn5TkvTPLtyZ59Lo2L0ryh1X1PUnun+QJE9YDAABHbbcPIrwwyau6++QkT0ry61V1j5qq6qKqOlhVB++4444dLxIAAA6bMkDfluSUmeWTh3WznpXkyiTp7v+T5L5JTly/oe6+tLv3d/f+ffv2TVQuAABsbcoAfW2SM6rq9Kq6T9YOEjywrs3fJnl8klTVv85agDbEDADAwposQHf3p5NcnOTqJO/M2tk2bqyqF1fVU4dmL0jy7Kr6iySvTfLM7u6pagIAgKM15UGE6e6rkly1bt0LZ27flOSrpqwBAAC2024fRAgAAEtFgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhh0gBdVedW1c1VdaiqLtmkzdOr6qaqurGqLp+yHgAAOFrHTbXhqjo2ycuTfF2S
W5NcW1UHuvummTZnJPmhJF/V3R+qqodMVQ8AAGyHuQJ0VZ2U5Itm23f3m7b4sbOTHOruW4ZtXJHk/CQ3zbR5dpKXd/eHhm3ePn/pAACw87YM0FX1U0mekbXg+5lhdSfZKkCflOS9M8u3Jnn0ujZfOjzGm5Mcm+RF3f36rcteXDV87yXdPgDAlHmjZm4va56ZZwT6G5M8rLs/MdHjn5HknCQnJ3lTVX1Fd394tlFVXZTkoiQ59dRTJygDAADmM89BhLckOf5ebPu2JKfMLJ88rJt1a5ID3f2p7n53kndlLVDfTXdf2t37u3v/vn377kUpAACwPeYZgf5Ykuur6g1J7hqF7u7v3eLnrk1yRlWdnrXgfEGSb17X5veSXJjkV6vqxKxN6bhlvtIBAGDnzROgDwxfo3T3p6vq4iRXZ21+82XdfWNVvTjJwe4+MNz3xKo6PL/6v3T3B8Y+FgAA7JTq3nr6dlXdJ8MBf0lu7u5PTVrVEezfv78PHjy4Ww+/pY0m3W/nRHwHEQIAU5syuyzTQYRVdV1371+/fp6zcJyT5NVJ3pO13/mUqvr2OU5jBwAAK2eeKRw/k+SJ3X1zklTVlyZ5bZKvnLKwVbRM77gAAO6t2rrJUpvnLBzHHw7PSdLd78q9OysHAAAsvXlGoA9W1a8k+Y1h+T8mWdxJyAAAMKF5AvR3J3leksOnrfuzJK+YrCIAAFhgWwbo4QqEPzt8AQDAnrZpgK6qK7v76VX19mxwzFt3P2LSypaAgwIBAI7esp2m90gj0N83fH/yThQCAADLYNOzcHT3+4abz+3uv5n9SvLcnSkPAAAWyzynsfu6Ddadt92FAADAMjjSHOjvztpI8xdX1Q0zd52Q5M1TF8b0lm2+ERvTj8DRcDzP6vB6sHOONAf68iR/kOS/JblkZv1Hu/uDk1YFAAALatMA3d13JrkzyYVJUlUPSXLfJA+oqgd099/uTIkAALA4tpwDXVVPqaq/SvLuJH+a5D1ZG5kGAIA9Z56DCH8iyWOSvKu7T0/y+CRvmbQqAABYUPME6E919weSHFNVx3T3NUn2T1wXe1jl7ge1AMAi8TrFlpfyTvLhqnpAkjcleU1V3Z7kH6ctCwAAFtM8I9DnJ/lYku9P8vokf53kKVMWBQAAi2qeEejnJHldd9+W5NUT1wPAOrt5nl7nCAa4p3lGoE9I8odV9WdVdXFVPXTqogAAYFFtOQLd3T+W5Meq6hFJnpHkT6vq1u5+wuTVLZrnPz+5/vq7Fq/ZoMm86+6tRd3WdlrUuhaVv9fq280+9v+1+vTxeIv6N1ukjHCknx+Vlc46K3nJS46ymu03zwj0Ybcn+bskH0jykGnKAQCAxVbdR57VVlXPTfL0JPuS/GaSK7v7ph2obUP79+/vgwcP7tbD381GcwM3ug79Rqe6ubdzCbfzOvfbua3ttKh1LSp/r9VnDjRT0sfjLerz7iJlhCNln3nz0yKoquu6+x6nb57nIMJTkjy/u6/f9qoAAGDJbBqgq+qB3f2RJD89LH/+7P3d/cGJawMAgIVzpBHoy5M8Ocl1WRtRXz/i/sUT1gUAAAtp0wDd3U8evp++c+UAAMBi2/IsHFX1hnnWAQDAXnCkOdD3TXK/JCdW1efls1M4HpjkpB2oDQAAFs6R5kA/J8nzk3xh1uZBHw7QH0nyC9OWBQAAi+lIc6B/PsnPV9X3dPfLdrAmAABYWPNcifCfq+rBhxeq6vOGi6sAAMCeM0+AfnZ3f/jwQnd/KMmzJ6sIAAAW2DwB+tiquusc0FV1bJL7TFcSAAAsrnku5f36JK+rql8alp+T5A+mKwkAABbXPAH6B5NclOS7huUbkvzLySoCAIAFtuUUju7+5yRvTfKeJGcn+dok75y2LAAAWExHupDKlya5cPh6f5LXJUl3P25nSgMAgMVzpCkcf5nkz5I8ubsPJUlVff+OVAUAAAvqSFM4/n2S9yW5pqp+uaoen89ejRAAAPakTQN0d/9ed1+Q
5MuSXJO1y3o/pKpeWVVP3KH6AABgocxzEOE/dvfl3f2UJCcneVvWzswBAAB7zjwXUrlLd3+ouy/t7sdPVRAAACyyUQEaAAD2ukkDdFWdW1U3V9WhqrrkCO2+qaq6qvZPWQ8AABytyQJ0VR2b5OVJzktyZpILq+rMDdqdkOT7snaxFgAAWGhTjkCfneRQd9/S3Z9MckWS8zdo9+NJfirJxyesBQAAtsWUAfqkJO+dWb51WHeXqnpUklO6+/ePtKGquqiqDlbVwTvuuGP7KwUAgDnt2kGEVXVMkp9N8oKt2g5n/tjf3fv37ds3fXEAALCJKQP0bUlOmVk+eVh32AlJHp7kjVX1niSPSXLAgYQAACyyKQP0tUnOqKrTq+o+SS5IcuDwnd19Z3ef2N2ndfdpSd6S5KndfXDCmgAA4KhMFqC7+9NJLk5ydZJ3Jrmyu2+sqhdX1VOnelwAAJjScVNuvLuvSnLVunUv3KTtOVPWAgAA28GVCAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGCE43a7AI7g8lr7/s29u3UAAHvb4UySyCUxAg0AAKMI0AAAMIIADQAAIwjQAAAwgoMIpzL1ZHuT+QGA7bRT2WUFcosRaAAAGEGABgCAEQRoAAAYwRzoMaacu2NOMwCwTKae07zAc6aNQAMAwAhGoIG7HP4cZPHe6wOzZj6ztL/CLhCg94jJg9ECf8yyUkz1AXaC55qdM+Hrpzda0zGFAwAARhCgt8vldfd37AAAHNmS5idTONhVU08tMacXgKmnMpgqsfcI0NjxV4R+BMZYvjE/NqIfd4cADTCBnRrx2s5tG6UDmI850AAAMIIADQAAI5jCsYJW5WPSVfk97q29/vsDi2WvH5S9KnON93o/bhcj0AAAMIIRaGAark559Bbtb7ho9SwbV/eDlSFAs3O8eACwV3jDudJM4QAAgBGMQAN7ywqPCo09OGhlDybyaRcwsUlHoKvq3Kq6uaoOVdUlG9z/n6vqpqq6oareUFVfNGU9cFQur7u/MAMwncPPuZ53WUCTBeiqOjbJy5Ocl+TMJBdW1Znrmr0tyf7ufkSS30ryP6aqByblSR5gPCGZJTXlFI6zkxzq7luSpKquSHJ+kpsON+jua2bavyXJt0xYD8CR+eh/eegrYBdNOYXjpCTvnVm+dVi3mWcl+YMJ6wEAgKO2EAcRVtW3JNmf5N9ucv9FSS5KklNPPXUHKwMmYfSQVeV/G/aEKUegb0tyyszyycO6u6mqJyT54SRP7e5PbLSh7r60u/d39/59+/ZNUiwAAMxjygB9bZIzqur0qrpPkguSHJhtUFWPTPJLWQvPt09YCwAAbIvJAnR3fzrJxUmuTvLOJFd2941V9eKqeurQ7KeTPCDJb1bV9VV1YJPNAQDAQph0DnR3X5XkqnXrXjhz+wlTPj4AAGw3l/IGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARjtvtAtha7XYBAMCedDiD9K5WsXiMQAMAwAgCNAAAjGAKxy7xkQgAsGr2yrRTAZrRhP/VMPskpy9h+dmnV4N+XA6mcCygyt55BwcALBc5xQj0vbLX/2kWlXftAMzLa8ZiWpZ+EaD3mKnD/7L846+CKftSP05jL01/8j80jZ16Dtdn05p6/zDQNz1TOAAAYAQj0CtgyhGDnRpFMjK+Gv0ILJapR5QXfcR6J19bduJx9mo/LiIBGjhqO/Vx5DI8ua/aC/bU21/0Pt3Jj9oX/W8BfJYAveJ2cx6UEdXtsyxhA1g9e2k+7U6N9O6GvdSPO0GAhg0YUV0Ny9CPO/VG
87BF/Qh4GfqKrRmxZ68QoNkVi/BO2JyynbHsL6j6cff5H1ode+GTSf29NwjQMNKiPEnvtJ0czdwJq9KP2/FmdBHe0G6HVQkuq7avjbUq/chqE6DhKHii3z67+bfUj8ttt94ArMqbsEWxKm/k2BsEaPa8nXwR9AIxLUGYMeyP09qp/VE/shsEaGDHrErAneIFexn/NlMFl2X8W2xEsIPVJUCzVLwg7S36m0WyKsGerZmew1YEaDYkuKwG/QgkAuGq0I+L45jdLgAAAJaJAA0AACMI0AAAMIIADQAAI0waoKvq3Kq6uaoOVdUlG9z/OVX1uuH+t1bVaVPWAwAAR2uyAF1VxyZ5eZLzkpyZ5MKqOnNds2cl+VB3/6skP5fkp6aqBwAAtsOUI9BnJznU3bd09yeTXJHk/HVtzk/y6uH2byV5fFXtqTNvVZxqDABYXpW9l2emDNAnJXnvzPKtw7oN23T3p5PcmeRfTFgTWxizA+zkDjP6cS6vta896oh/rw3+Nov0xLeKT8Sr9vuMtWq//6L9PhvWc4TnwEWrf8eNfH1Y1Ne5Pd+Pu6y6pzkVd1U9Lcm53f2fhuVvTfLo7r54ps07hja3Dst/PbR5/7ptXZTkomHxYUlunqTo+ZyY5P1btmIZ6dvVpF9Xl75dXfp2dS1b335Rd+9bv3LKKxHeluSUmeWTh3Ubtbm1qo5L8qAkH1i/oe6+NMmlE9U5SlUd7O79u10H20/frib9urr07erSt6trVfp2yikc1yY5o6pOr6r7JLkgyYF1bQ4k+fbh9tOS/ElPNSQOAADbYLIR6O7+dFVdnOTqJMcmuay7b6yqFyc52N0HkvzPJL9eVYeSfDBrIRsAABbWlFM40t1XJblq3boXztz+eJL/MGUNE1iIqSRMQt+uJv26uvTt6tK3q2sl+naygwgBAGAVuZQ3AACMIEDPaavLkrNcquo9VfX2qrq+qg4O6z6/qv6oqv5q+P55u10nW6uqy6rq9uG0mIfXbdiXtealw358Q1U9avcqZyub9O2Lquq2Yd+9vqqeNHPfDw19e3NVff3uVM08quqUqrqmqm6qqhur6vuG9fbdJXaEfl25/VaAnsOclyVn+Tyuu8+aOZ3OJUne0N1nJHnDsMzie1WSc9et26wvz0tyxvB1UZJX7lCN3Duvyj37Nkl+bth3zxqOtcnwnHxBki8ffuYVw3M3i+nTSV7Q3WcmeUyS5w19aN9dbpv1a7Ji+60APZ95LkvO8pu9tPyrk3zj7pXCvLr7TVk7i8+szfry/CS/1mvekuTBVfUFO1Ioo23St5s5P8kV3f2J7n53kkNZe+5mAXX3+7r7/w63P5rknVm7OrF9d4kdoV83s7T7rQA9n3kuS85y6SR/WFXXDVe6TJKHdvf7htt/l+Shu1Ma22CzvrQvr4aLh4/xL5uZaqVvl1RVnZbkkUneGvvuyljXr8mK7bcCNHvVV3f3o7L2seDzquprZu8cLujjFDUrQF+unFcm+ZIkZyV5X5Kf2dVqOCpV9YAkv53k+d39kdn77LvLa4N+Xbn9VoCezzyXJWeJdPdtw/fbk/xu1j4y+vvDHwkO32/fvQo5Spv1pX15yXX333f3Z7r7n5P8cj77ca++XTJVdXzWQtZruvt3htX23SW3Ub+u4n4rQM9nnsuSsySq6v5VdcLh20memOQduful5b89yf/anQrZBpv15YEk3zYc0f+YJHfOfFzMElg37/XfZW3fTdb69oKq+pyqOj1rB5v9+U7Xx3yqqrJ2NeJ3dvfPztxl311im/XrKu63k16JcFVsdlnyXS6Le++hSX53bT/PcUku7+7XV9W1Sa6sqmcl+ZskT9/FGplTVb02yTlJTqyqW5P8aJL/no378qokT8ragSofS/IdO14wc9ukb8+pqrOy9tH+e5I8J0m6+8aqujLJTVk7E8Dzuvszu1A28/mqJN+a5O1Vdf2w7r/GvrvsNuvXC1dtv3UlQgAAGMEUDgAAGEGABgCA
EQRoAAAYQYAGAIARBGgAABhBgAYYqar+Yd3yM6vqF3arnqNVVS+qqh+4lz/7zKr6wu1qB7AMBGiABVdVc5+zv6qOnbKWDTwzyTzBeN52AAtPgAbYJlV1QlW9e7iUbarqgYeXq+qNVfXzVXV9Vb2jqs4e2ty/qi6rqj+vqrdV1fnD+mdW1YGq+pMkb6iq+1XVlVV1U1X9blW9tar2D23/oap+pqr+Isljq+qFVXXt8DiXDlcHy2Y1DM4c7r+lqr53g9/t2Kp61fBzb6+q76+qpyXZn+Q1wzY/d6PH3qTdV1bVn1bVdVV19czlm793+B1vqKorJussgKPgSoQA433uzFW2kuTzkxzo7o9W1RuTfEOS30tyQZLf6e5PDRn2ft19VlV9TZLLkjw8yQ8n+ZPu/s6qenCSP6+qPx62+6gkj+juDw5TLD7U3WdW1cOTzD7+/ZO8tbtfkCRVdVN3v3i4/etJnpzkfw9tN6ohSb4syeOSnJDk5qp6ZXd/auYxzkpyUnc/fNjug7v7w8NVWn+guw8O639h/WN392/NthveYLwsyfndfUdVPSPJTyb5ziSXJDm9uz8x/D0AFo4RaIDx/qm7zzr8leSFM/f9Sj57meHvSPKrM/e9Nkm6+01JHjgExCcmuWQI5G9Mct8kpw7t/6i7Pzjc/uokVww//44kN8xs9zNJfntm+XHDCPXbk3xtki/fooYk+f3u/kR3vz/J7Vm75P2sW5J8cVW9rKrOTfKRjf80R3zswx6WteD+R8Pv/SNJTh7uuyFrI9XfkrVL+wIsHCPQANuou99cVadV1TlJjh3C7l13r2+epJJ8U3ffPHtHVT06yT/O+bAf7+7PDD933ySvSLK/u99bVS/KWig/Ug1J8omZdZ/JuteH7v5QVf2bJF+f5LuSPD1rI8azNW/12Hc1TXJjdz92g/u+IcnXJHlKkh+uqq/obkEaWChGoAG2368luTx3H31OkmckSVV9dZI7u/vOJFcn+Z6ZecqP3GSbb85aaE1VnZnkKzZpdziwvr+qHpDkaXPUsKWqOjHJMd3921kbMX7UcNdHszbtY6vHnm13c5J9VfXYYdvHV9WXV9UxSU7p7muS/GCSByV5wDz1AewkI9AA2+81SX4iw3SJGR+vqrclOT6fHb398SQvSXLDECDfnbU5y+u9Ismrq+qmJH+Z5MYk9wi/w7zkX07yjiR/l+TaOWqYx0lJfnWoMUl+aPj+qiS/WFX/lOSxSTZ77PXtnpbkpVX1oKy9Fr0kybuS/MawrpK8tLs/PKJGgB1R3es/zQPgaAxnnTi/u791Zt0bM3Ow3b3Y5rFJju/uj1fVlyT54yQP6+5PjtjGUdUAwBoj0ADbqKpeluS8JE/a5k3fL8k1wxksKslzx4RnALaPEWgAABjBQYQAADCCAA0AACMI0AAAMIIADQAAIwjQAAAwggANAAAj/H8fPT43QSgdLAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 864x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "It spends 4.56 minutes.\n",
      "END\n"
     ]
    }
   ],
   "source": [
    "def run():\n",
    "    \"\"\"Train with gradient descent, then validate and plot the result.\"\"\"\n",
    "    t1 = time.perf_counter()\n",
    "    # train with gradient descent\n",
    "    best_pr = train_gd()\n",
    "    # Separate file: the COBYLA run in a later cell also saves to the default\n",
    "    # \"fig5.png\" and would otherwise overwrite this figure.\n",
    "    validate_and_visualize(best_pr, save_path=\"fig5_gd.png\")\n",
    "    t2 = time.perf_counter()\n",
    "    print(\"It spends {:.2f} minutes.\\nEND\".format((t2 - t1)/60))\n",
    "\n",
    "\n",
    "run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5da97ca5-1376-4513-bb55-7503313011ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finish preparing data, with 126 items.\n",
      "Begin training, it may take several minutes on cpu.\n",
      "Finish training, the best parameters are:\n",
      "[ 1.52680435  1.62056776 -0.03166083  1.53614936  0.01427009 -0.31209835\n",
      "  0.09543332  0.01496778  0.29417818].\n",
      "\n",
      "The result image has been save at: fig5.png\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAtAAAAE9CAYAAAAiZVVdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaiElEQVR4nO3dfbRld1kf8O/DJIiYAGoGqkkw0QZsQBpwVoTqokEUEwTGVoqJ9QWlBIWoWOwyVhdStMtal4ogoNFGUAkh9a3TZTRqDOJiFcykxECCwTGgmRTNACGgSCD49I+7Jxxu7r1z9szd95575vNZ66x79t6/s/dz9+/uc753n/1S3R0AAGA+D9juAgAAYCcRoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGCEE7a7gLFOOeWUPuOMM7a7DAAAltwNN9zwge7evXr8jgvQZ5xxRvbv37/dZQAAsOSq6q/XGu8QDgAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGGGyAF1Vl1fVnVX1rnWmV1W9sqoOVNVNVfWEqWoBAIDNMuUe6NclOX+D6RckOWt4XJzktRPWAgAAm2KyAN3db0nyoQ2a7E3yq73ibUkeVlVfMFU9AACwGbbzGOhTk9w+M3xwGAcAAAtrR5xEWFUXV9X+qtp/6NCh7S5ntBoeAADHm2XMQdsZoO9IcvrM8GnDuPvp7su6e09379m9e/eWFAcAAGvZzgC9L8m3DVfjeGKSu7v7/dtYDwAAHNEJU824qt6Y5Lwkp1TVwSQ/muTEJOnuX0hydZKnJzmQ5GNJvmOqWgAAYLNMFqC7+6IjTO8kL5pq+QAAMIUdcRIhAAAsCgEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBghEkDdFWdX1W3VtWBqrp0jemPrKrrquodVXVTVT19ynoAAOBYTRagq2pXklcnuSDJ2UkuqqqzVzX7kSRXdffjk1yY5DVT1QMAAJthyj3Q5yY50N23dfcnklyZZO+qNp3kIcPzhyb5fxPWAwAAx2zKAH1qkttnhg8O42a9LMm3VNXBJFcn+Z61ZlRVF1fV/qraf+jQoSlqBQCAuWz3SYQXJXldd5+W5OlJfq2q7ldTd1/W3Xu6e8/u3bu3vEgAADhsygB9R5LTZ4ZPG8bNel6Sq5Kku/9PkgclOWXCmgAA4JhMGaCvT3JWVZ1ZVQ/MykmC+1a1+ZskT02SqvoXWQnQjtEAAGBhTRagu/veJJckuSbJu7NytY2bq+rlVfWsodlLkjy/qv48yRuTPLe7e6qaAADgWJ0w5cy7++qsnBw4O+6lM89vSfKVU9YAAACbabtPIgQAgB1FgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGEGABgCAEQRoAAAYQYAGAIARBGgAABhhrgBdVadW1b+qqicffsz5uvOr6taqOlBVl67T5jlVdUtV3VxV
V4wpHgAAttoJR2pQVT+Z5JuS3JLkU8PoTvKWI7xuV5JXJ/naJAeTXF9V+7r7lpk2ZyX5oSRf2d13VdXDj+q3AACALXLEAJ3kG5I8urvvGTnvc5Mc6O7bkqSqrkyyNytB/LDnJ3l1d9+VJN1958hlAADAlprnEI7bkpx4FPM+NcntM8MHh3GzHpXkUVX11qp6W1Wdv9aMquriqtpfVfsPHTp0FKUAAMDmmGcP9MeS3FhV1ya5by90d3/vJi3/rCTnJTktyVuq6su6+8Ozjbr7siSXJcmePXt6E5YLAABHZZ4AvW94jHVHktNnhk8bxs06mOTt3f3JJO+tqvdkJVBffxTLAwCAyR0xQHf366vqgVk53CJJbh0C75Fcn+SsqjozK8H5wiTfvKrN7yS5KMmvVNUpwzJum7N2AADYcvNcheO8JK9P8r4kleT0qvr27t7wKhzdfW9VXZLkmiS7klze3TdX1cuT7O/ufcO0p1XV4St8/Kfu/uAx/D4AADCp6t74kOKquiHJN3f3rcPwo5K8sbu/fAvqu589e/b0/v37t2PRc6nhZx9hHADA8WB1DqqZaYuejarqhu7es3r8PFfhOPFweE6S7n5Pju6qHMe9mnkAACyrZc8885xEuL+qfjnJrw/D/z7J4u4CBlgy27m3ZiftKQLYKvME6O9O8qIkhy9b96dJXjNZRQAAsMDmuQrHPUl+ZngAAMBxbd0AXVVXdfdzquqdWeObu+5+3KSV7QC+2gQAOHY77YILG+2B/r7h5zO2ohAAANgJ1r0KR3e/f3j6wu7+69lHkhduTXmstplntC7z2bEAcLyREbbOPJex+9o1xl2w2YUAAMBOsNEx0N+dlT3NX1xVN81MOjnJW6cujOPXTjsOartZX8CxcD7PeN532egY6CuS/F6Sn0hy6cz4j3b3hyatCgAAFtS6Abq7705yd5KLkqSqHp7kQUlOqqqTuvtvtqZEAABYHEc8BrqqnllVf5nkvUn+JMn7srJnGgAAjjvznET440memOQ93X1mkqcmedukVQEAwIKaJ0B/srs/mOQBVfWA7r4uyZ6J6wIAgIV0xFt5J/lwVZ2U5C1J3lBVdyb5h2nLAgCAxTTPHui9ST6W5PuT/H6Sv0ryzCmLAgCARTXPHugXJHlTd9+R5PUT1wMAAAttnj3QJyf5g6r606q6pKoeMXVRAACwqI4YoLv7v3T3Y5K8KMkXJPmTqvqjySsDAIAFNM8hHIfdmeRvk3wwycOnKWfBvfjFyY033jd43RpN5h13tBZ1XptpUetaVNbX8tvOPvb3tfz08XiLus4WKSNs9PpRWemcc5JXvOIYq9l889xI5YVV9eYk1yb5/CTP7+7HTV0YAAAsourujRtU/URWTiK8cUsqOoI9e/b0/v37t7uMJEnNPO9V43qddqvbH+0yj/b1U81rMy1qXYvK+lp+a73XHA/LZmvo4/EW9X13kTLCRtln3vy0CKrqhu6+3/1P1j2Eo6oe0t0fSfJTw/DnzU7v7g9tepUAALDgNjoG+ookz0hyQ1b+IVj9D8MXT1gXAAAspHUDdHc/Y/h55taVAwAAi22ekwivnWccAAAcDzY6BvpBSR6c5JSq+tx8+hCOhyQ5dQtqAwCAhbPRMdAvSPLiJF+YleOgDwfojyT5+WnLAgCAxbTRMdA/l+Tnqup7uvtVW1gTAAAsrCMeA53kn6rqYYcHqupzq+qF05UEAACLa54A/fzu/vDhge6+K8nzJ6sIAAAW2DwBeldV3XcN6KraleSB05UEAACLa6OTCA/7/SRvqqpfHIZfkOT3pisJAAAW1zwB+geTXJzku4bhm5L8s8kqAgCABXbEQzi6+5+SvD3J+5Kcm+Srk7x72rIAAGAxbXQjlUcluWh4fCDJm5Kku5+yNaUBAMDi2egQjr9I8qdJntHdB5Kkqr5/S6oCAIAFtdEhHP82yfuTXFdVv1RVT82n70YIAADHpXUDdHf/TndfmORLk1yXldt6P7yq
XltVT9ui+gAAYKHMcxLhP3T3Fd39zCSnJXlHVq7MAQAAx515bqRyn+6+q7sv6+6nTlUQAAAsslEBGgAAjneTBuiqOr+qbq2qA1V16QbtvrGquqr2TFkPAAAcq8kCdFXtSvLqJBckOTvJRVV19hrtTk7yfVm5WQsAACy0KfdAn5vkQHff1t2fSHJlkr1rtPuxJD+Z5OMT1gIAAJtiygB9apLbZ4YPDuPuU1VPSHJ6d//uhHUAAMCm2baTCKvqAUl+JslL5mh7cVXtr6r9hw4dmr44AABYx5QB+o4kp88MnzaMO+zkJI9N8uaqel+SJybZt9aJhMOl8/Z0957du3dPWDIAAGxsygB9fZKzqurMqnpgkguT7Ds8sbvv7u5TuvuM7j4jyduSPKu7909YEwAAHJPJAnR335vkkiTXJHl3kqu6++aqenlVPWuq5QIAwJROmHLm3X11kqtXjXvpOm3Pm7IWAADYDO5ECAAAIwjQAAAwggANAAAjCNAAADCCAA0AACMI0AAAMIIADQAAIwjQAAAwggANAAAjCNAAADCCAA0AACMI0AAAMIIADQAAIwjQAAAwggANAAAjCNAAADCCAA0AACMI0AAAMIIADQAAIwjQAAAwggANAAAjCNAAADCCAA0AACMI0AAAMIIADQAAIwjQAAAwggANAAAjCNAAADDCCdtdwNK6oj79/Jt7++oAAFgEh7PREuQiARpgUW3wYXP4X/TN/Bia+bd/7fku0YcfwLFwCAcAAIwgQI9xRX3moRkAAExjgXOXQzg2yxRfbW7iPKf4uhdYLGO3c+8LsNyOeFjWGFOd27VDDw0ToHeqTfxD3tQN7CiXvbM2m23k5NRpLcl2tah2xIc5K6zfURbhc3RT6Pe5CdAsLAFkHOsLGGMxvxhfXNYXswToJSRIbYMd+hXUtrGXg2Xlb3sc62vL+eZ3cwjQi8KbCACwkxzHO48EaDbHPBuRfxJG8U0CsEjsuRxp3nB5HIfQnUyAZqlt6hv+cfAmZ33BhJykOo71xQIToIHji2C//HzbBUzMjVQAAGAEARoAAEYQoAEAYIRJA3RVnV9Vt1bVgaq6dI3p/7Gqbqmqm6rq2qr6oinrAQCAYzVZgK6qXUleneSCJGcnuaiqzl7V7B1J9nT345L8RpL/PlU9AACwGabcA31ukgPdfVt3fyLJlUn2zjbo7uu6+2PD4NuSnDZhPQAAcMymDNCnJrl9ZvjgMG49z0vyexPWAwAAx2whrgNdVd+SZE+Sf73O9IuTXJwkj3zkI7ewMgAA+ExT7oG+I8npM8OnDeM+Q1V9TZIfTvKs7r5nrRl192Xdvae79+zevXuSYgEAYB5TBujrk5xVVWdW1QOTXJhk32yDqnp8kl/MSni+c8JaAABgU0wWoLv73iSXJLkmybuTXNXdN1fVy6vqWUOzn0pyUpL/WVU3VtW+dWYHAAALYdJjoLv76iRXrxr30pnnXzPl8gEAYLO5EyEAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgCNAAAjCBAAwDACAI0AACMIEADAMAIAjQAAIwgQAMAwAgnbHcBAMyvlmw5ADuRPdAAADCCPdBHYav3zGzm8uxVgnEObzO9rVWsb6Nteq1pa/0ei/6+MFvfovYDLJIpcsNWbXs7ZXsXoFkIi/4BfjSmftNZtnW2nUF1M5c9T6Bd5A+FzXKkD8Ep1vlWr9ed8kE/xla9by3b+uL4I0Bvk43eRMa+wcwzr800b33L9ka5lmUIAVtpWX5H/X5/1sn9LcvvsRH9Ps48v+NU/5ht5j/3W507FpEAzaaad8M52g1s3o129fSxb0jzzmuso/2AmOJ1y7K+lnEvIFtr3r+hrd4ON5rXkaYd6zY97zLned3xur68by03
AZo1zXvs5E52POzt2Ew7cX1NFexZTJv5zd4iELLGWeb1dTx8Ju80AjQ70ma+UW7Hm+5OPiljkdfXTgxJbJ95/onajn+0dvKJ49YXxwsBGuAoHe/HfwouwPHKdaABAGAEARoAAEYQoAEAYAQBGgAARhCgAQBgBAEaAABGEKABAGAEARoAAEaYNEBX1flVdWtVHaiqS9eY/llV9aZh+tur6owp6wEAgGM1WYCuql1JXp3kgiRnJ7moqs5e1ex5Se7q7n+e5GeT/ORU9QAAwGaYcg/0uUkOdPdt3f2JJFcm2buqzd4krx+e/0aSp1aVO7oCALCwpgzQpya5fWb44DBuzTbdfW+Su5N8/oQ1AQDAManunmbGVc9Ocn53/4dh+FuTfEV3XzLT5l1Dm4PD8F8NbT6wal4XJ7l4GHx0klsnKXo+pyT5wBFbsRPp2+WkX5eXvl1e+nZ57bS+/aLu3r165AkTLvCOJKfPDJ82jFurzcGqOiHJQ5N8cPWMuvuyJJdNVOcoVbW/u/dsdx1sPn27nPTr8tK3y0vfLq9l6dspD+G4PslZVXVmVT0wyYVJ9q1qsy/Jtw/Pn53kj3uqXeIAALAJJtsD3d33VtUlSa5JsivJ5d19c1W9PMn+7t6X5H8k+bWqOpDkQ1kJ2QAAsLCmPIQj3X11kqtXjXvpzPOPJ/l3U9YwgYU4lIRJ6NvlpF+Xl75dXvp2eS1F3052EiEAACwjt/IGAIARBOg5Hem25OwsVfW+qnpnVd1YVfuHcZ9XVX9YVX85/Pzc7a6TI6uqy6vqzuGymIfHrdmXteKVw3Z8U1U9Yfsq50jW6duXVdUdw7Z7Y1U9fWbaDw19e2tVfd32VM08qur0qrquqm6pqpur6vuG8bbdHWyDfl267VaAnsOctyVn53lKd58zczmdS5Nc291nJbl2GGbxvS7J+avGrdeXFyQ5a3hcnOS1W1QjR+d1uX/fJsnPDtvuOcO5Nhneky9M8pjhNa8Z3rtZTPcmeUl3n53kiUleNPShbXdnW69fkyXbbgXo+cxzW3J2vtlby78+yTdsXynMq7vfkpWr+Mxary/3JvnVXvG2JA+rqi/YkkIZbZ2+Xc/eJFd29z3d/d4kB7Ly3s0C6u73d/f/HZ5/NMm7s3J3YtvuDrZBv65nx263AvR85rktOTtLJ/mDqrphuNNlkjyiu98/PP/bJI/YntLYBOv1pW15OVwyfI1/+cyhVvp2h6qqM5I8PsnbY9tdGqv6NVmy7VaA5nj1Vd39hKx8Lfiiqnry7MThhj4uUbME9OXSeW2SL0lyTpL3J/npba2GY1JVJyX5zSQv7u6PzE6z7e5ca/Tr0m23AvR85rktOTtId98x/LwzyW9n5Sujvzv8leDw887tq5BjtF5f2pZ3uO7+u+7+VHf/U5Jfyqe/7tW3O0xVnZiVkPWG7v6tYbRtd4dbq1+XcbsVoOczz23J2SGq6nOq6uTDz5M8Lcm78pm3lv/2JP9reypkE6zXl/uSfNtwRv8Tk9w983UxO8Cq417/TVa23WSlby+sqs+qqjOzcrLZn211fcynqiordyN+d3f/zMwk2+4Otl6/LuN2O+mdCJfFercl3+ayOHqPSPLbK9t5TkhyRXf/flVdn+Sqqnpekr9O8pxtrJE5VdUbk5yX5JSqOpjkR5P8t6zdl1cneXpWTlT5WJLv2PKCmds6fXteVZ2Tla/235fkBUnS3TdX1VVJbsnKlQBe1N2f2oaymc9XJvnWJO+sqhuHcf85tt2dbr1+vWjZtlt3IgQAgBEcwgEAACMI0AAAMIIADQAAIwjQAAAwggANAAAjCNAAI1XV368afm5V/fx21XOsquplVfUDR/na51bVF25WO4CdQIAGWHBVNfc1+6tq15S1rOG5SeYJxvO2A1h4AjTAJqmqk6vqvcOtbFNVDzk8XFVvrqqfq6obq+pdVXXu0OZzquryqvqzqnpHVe0dxj+3qvZV1R8nubaqHlxVV1XVLVX121X1
9qraM7T9+6r66ar68yRPqqqXVtX1w3IuG+4OlvVqGJw9TL+tqr53jd9tV1W9bnjdO6vq+6vq2Un2JHnDMM/PXmvZ67T78qr6k6q6oaqumbl98/cOv+NNVXXlZJ0FcAzciRBgvM+euctWknxekn3d/dGqenOSr0/yO0kuTPJb3f3JIcM+uLvPqaonJ7k8yWOT/HCSP+7u76yqhyX5s6r6o2G+T0jyuO7+0HCIxV3dfXZVPTbJ7PI/J8nbu/slSVJVt3T3y4fnv5bkGUn+99B2rRqS5EuTPCXJyUlurarXdvcnZ5ZxTpJTu/uxw3wf1t0fHu7S+gPdvX8Y//Orl93dvzHbbvgH41VJ9nb3oar6piT/Ncl3Jrk0yZndfc+wPgAWjj3QAOP9Y3efc/iR5KUz0345n77N8Hck+ZWZaW9Mku5+S5KHDAHxaUkuHQL5m5M8KMkjh/Z/2N0fGp5/VZIrh9e/K8lNM/P9VJLfnBl+yrCH+p1JvjrJY45QQ5L8bnff090fSHJnVm55P+u2JF9cVa+qqvOTfGTtVbPhsg97dFaC+x8Ov/ePJDltmHZTVvZUf0tWbu0LsHDsgQbYRN391qo6o6rOS7JrCLv3TV7dPEkl+cbuvnV2QlV9RZJ/mHOxH+/uTw2ve1CS1yTZ0923V9XLshLKN6ohSe6ZGfeprPp86O67qupfJvm6JN+V5DlZ2WM8W/ORln1f0yQ3d/eT1pj29UmenOSZSX64qr6suwVpYKHYAw2w+X41yRX5zL3PSfJNSVJVX5Xk7u6+O8k1Sb5n5jjlx68zz7dmJbSmqs5O8mXrtDscWD9QVSclefYcNRxRVZ2S5AHd/ZtZ2WP8hGHSR7Ny2MeRlj3b7tYku6vqScO8T6yqx1TVA5Kc3t3XJfnBJA9NctI89QFsJXugATbfG5L8eIbDJWZ8vKrekeTEfHrv7Y8leUWSm4YA+d6sHLO82muSvL6qbknyF0luTnK/8Dscl/xLSd6V5G+TXD9HDfM4NcmvDDUmyQ8NP1+X5Beq6h+TPCnJeste3e7ZSV5ZVQ/NymfRK5K8J8mvD+MqySu7+8MjagTYEtW9+ts8AI7FcNWJvd39rTPj3pyZk+2OYp67kpzY3R+vqi9J8kdJHt3dnxgxj2OqAYAV9kADbKKqelWSC5I8fZNn/eAk1w1XsKgkLxwTngHYPPZAAwDACE4iBACAEQRoAAAYQYAGAIARBGgAABhBgAYAgBEEaAAAGOH/A/80OmIml0iTAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 864x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "It spends 8.17 minutes.\n",
      "END\n"
     ]
    }
   ],
   "source": [
    "def run2():\n",
    "    \"\"\"Train the witness circuit with COBYLA, then validate and plot the result.\n",
    "\n",
    "    Prints the total elapsed time in minutes.\n",
    "\n",
    "    Returns:\n",
    "        The best parameters returned by ``train_cobyla`` (so later cells can reuse them).\n",
    "    \"\"\"\n",
    "    # perf_counter is monotonic, so the measured interval is not distorted\n",
    "    # by system clock adjustments the way time.time() can be.\n",
    "    start = time.perf_counter()\n",
    "    best_pr = train_cobyla()  # optimize the circuit parameters with COBYLA\n",
    "    validate_and_visualize(best_pr)\n",
    "    elapsed_min = (time.perf_counter() - start) / 60\n",
    "    print(\"It spends {:.2f} minutes.\\nEND\".format(elapsed_min))\n",
    "    return best_pr\n",
    "\n",
    "\n",
    "# Bind the result so the trained parameters are available afterwards\n",
    "# without echoing the parameter array as cell output.\n",
    "best_params_cobyla = run2()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0a32975-ac37-4705-b79a-6ddd61d54a65",
   "metadata": {},
   "source": [
    "验证结果表明：使用 COBYLA 优化器训练后，模型对识别出的纠缠态给出较大的激活值（Activation），且各态之间的数值差异较小，结果较为稳定，更具鲁棒性。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
